今天要來利用python的sklearn來實作decision tree,今天的範例同樣是要預測sklearn裡的iris資料集。同樣的第一步,導入iris資料集的特徵(feature)為x、標籤(label)為y。完成後的x和y如下:
from sklearn import datasets
import pandas as pd
data = datasets.load_iris()
x = pd.DataFrame(data["data"], columns = data["feature_names"])
y = pd.DataFrame(data["target"], columns = ["target"])
接著要建構模型,這邊直接利用sklearn裡面的函式建構,建構完成後直接將x和y丟進去訓練。
from sklearn import tree
decisionTree = tree.DecisionTreeClassifier()
decisionTree.fit(x, y)
訓練完成過後,我們可以利用sklearn內建的函式—plot_tree來看看我們所做出的decision tree模型。
tree.plot_tree(decisionTree)
如果想要將圖形給儲存下來,可以利用python的graphviz套件,他可以像下面一樣輸出一個.dot檔。
import graphviz
with open("decision_tree.dot", "w") as f:
tree.export_graphviz(decisionTree, out_file = f)
最後利用下方的指令就可以獲得一個pdf檔囉!
!dot -Tpdf decision_tree.dot -o decision_tree.pdf